/*
 * Copyright (c) 2020 Huawei Technologies Co.,Ltd.
 *
 * openGauss is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *          http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 * -------------------------------------------------------------------------
 *
 * security_gs_ktool.cpp
 *      gs_ktool is an independent key management tool provided by GaussDB Kernel. It can generate and store
 *      symmetric keys of 16 to 112 bytes.
 *      When creating a client master key (CMK) with KEY_STORE = gs_ktool:
 *          1. KEY_PATH: gs_ktool identifies keys by $key_id, so KEY_PATH = "gs_ktool/$key_id"
 *          2. ALGORITHM: gs_ktool cannot generate asymmetric key pairs, so keys generated by gs_ktool can only
 *          be used with symmetric algorithms (AES_256_CBC, AES_256_GCM, SM4)
 *      Before registering gs_ktool, make sure gs_ktool is installed on the system and that its environment
 *      variables and configuration files are available.
 *
 * IDENTIFICATION
 *	  src/gausskernel/security/keymgr/src/ktool/security_gs_ktool.cpp
 *
 * -------------------------------------------------------------------------
 */
#include "keymgr/ktool/security_gs_ktool.h"
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef ENABLE_KT
#include "gs_ktool/kt_interface.h"
#endif
#include "keymgr/security_key_mgr.h"
#include "keymgr/comm/security_error.h"
#include "keymgr/localkms/security_cmkem_comm_algorithm.h"

const int MAX_KEYPATH_LEN = 64;
const int UPPER_TO_LOWER_OFFSET = 'a' - 'A';

static const char *g_support_algo[] = {"AES_256_CBC", "SM4", "AES_256_GCM", NULL};

static CmkemErrCode check_cmk_id_validity(KeyInfo info);
static CmkemErrCode check_cmk_algo_validity(const char *cmk_algo);
static CmkemErrCode check_cmk_entity_validity(KeyInfo info);
static void get_cmk_id_from_key_path(const char *key_path, unsigned int *cmk_id);
static CmkemErrCode read_cmk_plain(GsKtoolMgr *ktool, unsigned int cmk_id, CmkemUStr **cmk_plain);

CmkCacheList *init_cmk_cache_list();
void push_cmk_to_cache(GsKtoolMgr *ktool, unsigned int cmk_id, const unsigned char *cmk_plian);
bool get_cmk_from_cache(GsKtoolMgr *ktool, unsigned int cmk_id, unsigned char *cmk_plain);
void free_cmk_cache_list(CmkCacheList *cmk_cahce_list);

/*
 * Copy the NUL-terminated string `in` into `out`, converting ASCII 'A'..'Z' to lower case.
 * Every other byte is copied unchanged and the result is NUL-terminated.
 * If `in` or `out` is NULL, or `out_buf_len` cannot hold strlen(in) + 1 bytes, nothing is
 * written and `out` is left untouched.
 */
void cmkem_tolower(const char *in, char *out, size_t out_buf_len)
{
    if (in == NULL || out == NULL) {
        return;
    }

    /* measured once; the previous loop condition re-ran strlen() on every iteration */
    const size_t in_len = strlen(in);
    if (out_buf_len < in_len + 1) {
        return;
    }

    for (size_t i = 0; i < in_len; i++) {
        const char c = in[i];
        out[i] = (c >= 'A' && c <= 'Z') ? (char)(c + UPPER_TO_LOWER_OFFSET) : c;
    }
    out[in_len] = '\0';
}

/*
 * Validate the KEY_PATH of a gs_ktool cmk.
 * A valid path is "gs_ktool/" (prefix compared ignoring ASCII case) followed by one or more
 * decimal digits whose value fits in an int, e.g. "gs_ktool/1" or "GS_KTOOL/12".
 * On failure an explanatory message is recorded with cmkem_errmsg() and
 * CMKEM_CHECK_CMK_ID_ERR is returned.
 * info.id must not be NULL (ktool_mk_select checks this before calling).
 */
static CmkemErrCode check_cmk_id_validity(KeyInfo info)
{
    const char *key_path_tag = "gs_ktool/";
    const size_t tag_len = strlen(key_path_tag);

    if (strlen(info.id) <= tag_len) {
        cmkem_errmsg("invalid key path: '%s', it should be like \"%s1\".", info.id, key_path_tag);
        return CMKEM_CHECK_CMK_ID_ERR;
    }

    /*
     * Compare the prefix in place, folding ASCII upper case to lower case one character at a
     * time. This avoids copying the whole caller-controlled path into a variable-length array.
     */
    for (size_t i = 0; i < tag_len; i++) {
        char c = info.id[i];
        if (c >= 'A' && c <= 'Z') {
            c = (char)(c + UPPER_TO_LOWER_OFFSET);
        }
        if (c != key_path_tag[i]) {
            cmkem_errmsg("invalid key path: '%s', it should be like \"%s1\".", info.id, key_path_tag);
            return CMKEM_CHECK_CMK_ID_ERR;
        }
    }

    /*
     * The id part is referenced directly inside info.id rather than copied into a fixed-size
     * scratch buffer, so arbitrarily long paths cannot overflow the stack.
     */
    const char *id_str = info.id + tag_len;
    bool has_invalid_char = false;
    bool out_of_range = false;
    unsigned long long id_val = 0;

    for (const char *p = id_str; *p != '\0'; p++) {
        if (*p < '0' || *p > '9') {
            has_invalid_char = true;
            break;
        }
        if (!out_of_range) {
            /* id_val <= INT_MAX here, so id_val * 10 + 9 cannot overflow unsigned long long */
            id_val = id_val * 10 + (unsigned long long)(*p - '0');
            /* keep the id within int range so every later numeric conversion of it is well defined */
            out_of_range = (id_val > (unsigned long long)INT_MAX);
        }
    }

    if (has_invalid_char) {
        cmkem_errmsg("invalid key path: '%s', '%s' is expected to be an integer.", info.id, id_str);
        return CMKEM_CHECK_CMK_ID_ERR;
    }

    if (out_of_range) {
        cmkem_errmsg("invalid key path: '%s', '%s' is out of range.", info.id, id_str);
        return CMKEM_CHECK_CMK_ID_ERR;
    }

    return CMKEM_SUCCEED;
}

/*
 * Parse the numeric key id out of a KEY_PATH of the form "gs_ktool/<id>".
 * The first strlen("gs_ktool/") characters are skipped without being compared. The id is
 * read as a base-10 number up to the first non-digit character.
 *
 * Returns 0 when str is NULL, has nothing after the prefix, does not continue with a digit
 * right after the prefix, or holds a number that does not fit in an unsigned int.
 */
unsigned int get_key_id(const char *str)
{
    const size_t prefix_len = strlen("gs_ktool/");

    if (str == NULL || strlen(str) <= prefix_len) {
        return 0;
    }

    const char *digits = str + prefix_len;
    /* strtoul() would otherwise accept leading blanks and a '-' sign (silently wrapping the value) */
    if (*digits < '0' || *digits > '9') {
        return 0;
    }

    /* unlike atoi(), strtoul() reports overflow instead of invoking undefined behavior */
    errno = 0;
    const unsigned long value = strtoul(digits, NULL, 10);
    if (errno == ERANGE || value > UINT_MAX) {
        return 0;
    }
    return (unsigned int)value;
}

/*
 * Check that cmk_algo names one of the algorithms in g_support_algo (case-insensitive).
 * Returns CMKEM_SUCCEED on a match. Otherwise it records an error message listing every
 * supported algorithm and returns CMKEM_CHECK_ALGO_ERR.
 */
static CmkemErrCode check_cmk_algo_validity(const char *cmk_algo)
{
    char error_msg_buf[MAX_CMKEM_ERRMSG_BUF_SIZE] = {0};
    error_t rc = 0;

    for (size_t i = 0; g_support_algo[i] != NULL; i++) {
        if (strcasecmp(cmk_algo, g_support_algo[i]) == 0) {
            return CMKEM_SUCCEED;
        }
    }

    rc = sprintf_s(error_msg_buf, MAX_CMKEM_ERRMSG_BUF_SIZE, "unsupported algorithm '%s', gs_ktool only supports: ",
        cmk_algo);
    km_securec_check_ss(rc, "", "");

    /* build "A, B, C" with no trailing separator */
    for (size_t i = 0; g_support_algo[i] != NULL; i++) {
        if (i > 0) {
            rc = strcat_s(error_msg_buf, MAX_CMKEM_ERRMSG_BUF_SIZE, ", ");
            km_securec_check(rc, "", "");
        }
        rc = strcat_s(error_msg_buf, MAX_CMKEM_ERRMSG_BUF_SIZE, g_support_algo[i]);
        km_securec_check(rc, "", "");
    }

    cmkem_errmsg("%s", error_msg_buf);
    return CMKEM_CHECK_ALGO_ERR;
}

/*
 * Check that the key referenced by info.id exists in gs_ktool and is long enough for
 * info.algo. Without ENABLE_KT no check is possible and CMKEM_SUCCEED is returned.
 * Every failure path records a message with cmkem_errmsg(), because callers report it via
 * get_cmkem_errmsg().
 */
static CmkemErrCode check_cmk_entity_validity(KeyInfo info)
{
#ifdef ENABLE_KT
    unsigned int cmk_len = 0;
    unsigned int cmk_id = get_key_id(info.id);
    if (!get_cmk_len(cmk_id, &cmk_len)) {
        cmkem_errmsg("failed to read cmk from gs_ktool, key id: %u.", cmk_id);
        return CMKEM_GS_KTOOL_ERR;
    }

    if (cmk_len < get_key_len_by_algo(get_algo_by_str(info.algo))) {
        /* previously returned without a message, so the caller printed a stale/empty one */
        cmkem_errmsg("the cmk in gs_ktool (key id: %u) is %u bytes, which is too short for algorithm '%s'.",
            cmk_id, cmk_len, info.algo);
        return CMKEM_CHECK_ALGO_ERR;
    }
#else
    (void)info;
#endif
    return CMKEM_SUCCEED;
}

static void get_cmk_id_from_key_path(const char *key_path, unsigned int *cmk_id)
{
    const char *key_path_tag = "gs_ktool/";
    *cmk_id = (unsigned int)atoi(key_path + strlen(key_path_tag));
}

/*
 * Load the plaintext of gs_ktool key `cmk_id` into a newly allocated CmkemUStr returned
 * through *cmk_plain. The caller frees it with free_cmkem_ustr_with_erase().
 * The manager's LRU cache is consulted first. On a miss the key is read from gs_ktool and
 * pushed into the cache.
 * On any failure *cmk_plain is set to NULL and an error code is returned.
 */
static CmkemErrCode read_cmk_plain(GsKtoolMgr *ktool, unsigned int cmk_id, CmkemUStr **cmk_plain)
{
#ifdef ENABLE_KT
    unsigned int cmk_len = 0;
    CmkemUStr *key = malloc_cmkem_ustr(AES256_KEY_BUF_LEN);

    *cmk_plain = NULL;
    if (key == NULL) {
        return CMKEM_MALLOC_MEM_ERR;
    }

    if (get_cmk_from_cache(ktool, cmk_id, key->ustr_val)) {
        /*
         * case a: cache hit. The cache keeps only the key bytes, not their length, so ask
         * gs_ktool for the length; otherwise ustr_len would stay 0 for cached keys.
         */
        if (!get_cmk_len(cmk_id, &cmk_len)) {
            free_cmkem_ustr_with_erase(key);
            return CMKEM_GS_KTOOL_ERR;
        }
    } else {
        /* case b: cache miss, read the key from gs_ktool and remember it */
        if (!get_cmk_plain(cmk_id, key->ustr_val, &cmk_len)) {
            free_cmkem_ustr_with_erase(key);
            return CMKEM_GS_KTOOL_ERR;
        }
        push_cmk_to_cache(ktool, cmk_id, key->ustr_val);
    }

    key->ustr_len = (size_t)cmk_len;
    *cmk_plain = key;
    return CMKEM_SUCCEED;
#else
    /* without gs_ktool support there is no key to hand out; do not report success with no key */
    (void)ktool;
    (void)cmk_id;
    *cmk_plain = NULL;
    cmkem_errmsg("gs_ktool is not supported in this build.");
    return CMKEM_GS_KTOOL_ERR;
#endif
}

/* LRU cache */
/*
 * Allocate an empty cmk cache: no nodes and a node count of zero.
 * Returns NULL, after recording an error with cmkem_errmsg(), if allocation fails.
 */
CmkCacheList *init_cmk_cache_list()
{
    CmkCacheList *list = (CmkCacheList *)km_alloc(sizeof(*list));

    if (list != NULL) {
        list->first_cmk_node = NULL;
        list->cmk_node_cnt = 0;
        return list;
    }

    cmkem_errmsg("failed to malloc memory.");
    return NULL;
}

/*
 * Insert a copy of the first DEFAULT_CMK_CACHE_LEN bytes of cmk_plain at the head
 * (most recently used end) of the manager's LRU cache.
 * While the cache holds MAX_CMK_CACHE_NODE_CNT or more nodes, the tail (least recently used)
 * node is wiped and freed first.
 * Allocation failure is reported with cmkem_errmsg() and leaves the cache unchanged; a
 * manager without a cache is ignored.
 */
void push_cmk_to_cache(GsKtoolMgr *ktool, unsigned int cmk_id, const unsigned char *cmk_plain)
{
    CmkCacheList *cache_list = ktool->cache;
    if (cache_list == NULL) {
        return;
    }

    CmkCacheNode *new_node = (CmkCacheNode *)km_alloc(sizeof(CmkCacheNode));
    if (new_node == NULL) {
        cmkem_errmsg("failed to malloc memory.");
        return;
    }
    new_node->cmk_id = cmk_id;
    for (size_t i = 0; i < DEFAULT_CMK_CACHE_LEN; i++) {
        new_node->cmk_plain[i] = cmk_plain[i];
    }

    /*
     * Evict from the tail until there is room. Unlike the old code this works for a
     * single-node list (no ->next->next dereference) and decrements the count on eviction.
     */
    while (cache_list->cmk_node_cnt >= MAX_CMK_CACHE_NODE_CNT && cache_list->first_cmk_node != NULL) {
        CmkCacheNode **tail_link = &cache_list->first_cmk_node;
        while ((*tail_link)->next != NULL) {
            tail_link = &(*tail_link)->next;
        }
        CmkCacheNode *victim = *tail_link;
        *tail_link = NULL;

        /* the node holds a plaintext key: clear it through a volatile pointer so it is not optimized away */
        volatile unsigned char *key_bytes = (volatile unsigned char *)victim->cmk_plain;
        for (size_t i = 0; i < DEFAULT_CMK_CACHE_LEN; i++) {
            key_bytes[i] = 0;
        }
        cmkem_free(victim);
        cache_list->cmk_node_cnt--;
    }

    new_node->next = cache_list->first_cmk_node;
    cache_list->first_cmk_node = new_node;
    cache_list->cmk_node_cnt++;
}

/*
 * Look up cmk_id in the manager's LRU cache.
 * On a hit, DEFAULT_CMK_CACHE_LEN key bytes are copied into cmk_plain, the node is moved to
 * the head of the list (most recently used) and true is returned.
 * Returns false on a miss or when the manager has no cache.
 */
bool get_cmk_from_cache(GsKtoolMgr *ktool, unsigned int cmk_id, unsigned char *cmk_plain)
{
    CmkCacheList *cache_list = ktool->cache;
    if (cache_list == NULL) {
        return false;
    }

    CmkCacheNode *prev_node = NULL;
    for (CmkCacheNode *cur = cache_list->first_cmk_node; cur != NULL; prev_node = cur, cur = cur->next) {
        if (cur->cmk_id != cmk_id) {
            continue;
        }

        for (size_t i = 0; i < DEFAULT_CMK_CACHE_LEN; i++) {
            cmk_plain[i] = cur->cmk_plain[i];
        }

        /* move the hit to the front so the tail stays the least recently used node */
        if (prev_node != NULL) {
            prev_node->next = cur->next;
            cur->next = cache_list->first_cmk_node;
            cache_list->first_cmk_node = cur;
        }
        return true;
    }

    return false;
}

void free_cmk_cache_list(CmkCacheList *cmk_cahce_list)
{
    CmkCacheNode *to_free = NULL;
    CmkCacheNode *cur_node = NULL;

    if (cmk_cahce_list == NULL) {
        return;
    }

    cur_node = cmk_cahce_list->first_cmk_node;
    while (cur_node != NULL) {
        to_free = cur_node;
        cur_node = cur_node->next;
        cmkem_free(to_free);
    }

    km_free(cmk_cahce_list);
}


/*
 * mk_select callback: validate the KEY_PATH / ALGORITHM of a gs_ktool client master key.
 * The algorithm is checked first, then the key path format, then the key stored in gs_ktool.
 * Returns a newly duplicated "active" string on success. On failure the error is reported
 * through the manager's KmErr and NULL is returned.
 */
char *ktool_mk_select(KeyMgr *kmgr, KeyInfo info)
{
    GsKtoolMgr *mgr = (GsKtoolMgr *)(void *)kmgr;

    /* both arguments are mandatory */
    if (info.id == NULL) {
        km_err_msg(mgr->kmgr.err, "failed to create client master key, failed to find arg: KEY_PATH.");
        return NULL;
    }
    if (info.algo == NULL) {
        km_err_msg(mgr->kmgr.err, "failed to create client master key, failed to find arg: ALGORITHM.");
        return NULL;
    }

    /* run the checks in order, stopping at the first failure */
    CmkemErrCode rc = check_cmk_algo_validity(info.algo);
    if (rc == CMKEM_SUCCEED) {
        rc = check_cmk_id_validity(info);
    }
    if (rc == CMKEM_SUCCEED) {
        rc = check_cmk_entity_validity(info);
    }

    if (rc != CMKEM_SUCCEED) {
        km_err_msg(mgr->kmgr.err, "%s", get_cmkem_errmsg(rc));
        return NULL;
    }
    return km_strdup("active");
}

/*
 * mk_encrypt callback: encrypt `plain` with the gs_ktool key named by info.id
 * ("gs_ktool/<id>") using the algorithm named by info.algo.
 * On success the ciphertext buffer is handed to the returned KmUnStr and only the CmkemUStr
 * wrapper is freed. On failure the error is reported through the manager's KmErr and an
 * all-zero KmUnStr is returned.
 */
KmUnStr ktool_mk_encrypt(KeyMgr *kmgr, KeyInfo info, KmUnStr plain)
{
    GsKtoolMgr *kt = (GsKtoolMgr *)(void *)kmgr;
    CmkemErrCode ret = CMKEM_SUCCEED;
    KmUnStr cipher = {0};
    CmkemUStr _plain = {plain.val, plain.len};
    CmkemUStr *_cipher = NULL;

    CmkemUStr *cmk_plain = NULL;
    unsigned int cmk_id = 0;

    /* same mandatory arguments as ktool_mk_select(); both are used below without NULL checks */
    if (info.id == NULL) {
        km_err_msg(kt->kmgr.err, "failed to encrypt with client master key, failed to find arg: KEY_PATH.");
        return cipher;
    }
    if (info.algo == NULL) {
        km_err_msg(kt->kmgr.err, "failed to encrypt with client master key, failed to find arg: ALGORITHM.");
        return cipher;
    }

    AlgoType cmk_algo = get_algo_by_str(info.algo);
    get_cmk_id_from_key_path(info.id, &cmk_id);

    ret = read_cmk_plain(kt, cmk_id, &cmk_plain);
    if (ret != CMKEM_SUCCEED) {
        km_err_msg(kt->kmgr.err, "%s", get_cmkem_errmsg(ret));
        return cipher;
    }

    ret = encrypt_with_symm_algo(cmk_algo, &_plain, cmk_plain, &_cipher);
    /* the plaintext key is no longer needed whether or not encryption succeeded */
    free_cmkem_ustr_with_erase(cmk_plain);
    if (ret != CMKEM_SUCCEED) {
        km_err_msg(kt->kmgr.err, "%s", get_cmkem_errmsg(ret));
        return cipher;
    }

    cipher.val = _cipher->ustr_val;
    cipher.len = _cipher->ustr_len;
    km_free(_cipher);
    return cipher;
}

/*
 * mk_decrypt callback: decrypt `cipher` with the gs_ktool key named by info.id
 * ("gs_ktool/<id>") using the algorithm named by info.algo.
 * On success the plaintext buffer is handed to the returned KmUnStr and only the CmkemUStr
 * wrapper is freed. On failure the error is reported through the manager's KmErr and an
 * all-zero KmUnStr is returned.
 */
KmUnStr ktool_mk_decrypt(KeyMgr *kmgr, KeyInfo info, KmUnStr cipher)
{
    GsKtoolMgr *kt = (GsKtoolMgr *)(void *)kmgr;
    CmkemErrCode ret = CMKEM_SUCCEED;
    KmUnStr plain = {0};
    CmkemUStr _cipher = {cipher.val, cipher.len};
    CmkemUStr *_plain = NULL;

    CmkemUStr *cmk_plain = NULL;
    unsigned int cmk_id = 0;

    /* same mandatory arguments as ktool_mk_select(); both are used below without NULL checks */
    if (info.id == NULL) {
        km_err_msg(kt->kmgr.err, "failed to decrypt with client master key, failed to find arg: KEY_PATH.");
        return plain;
    }
    if (info.algo == NULL) {
        km_err_msg(kt->kmgr.err, "failed to decrypt with client master key, failed to find arg: ALGORITHM.");
        return plain;
    }

    AlgoType cmk_algo = get_algo_by_str(info.algo);
    get_cmk_id_from_key_path(info.id, &cmk_id);

    ret = read_cmk_plain(kt, cmk_id, &cmk_plain);
    if (ret != CMKEM_SUCCEED) {
        km_err_msg(kt->kmgr.err, "%s", get_cmkem_errmsg(ret));
        return plain;
    }

    ret = decrypt_with_symm_algo(cmk_algo, &_cipher, cmk_plain, &_plain);
    /* the plaintext key is no longer needed whether or not decryption succeeded */
    free_cmkem_ustr_with_erase(cmk_plain);
    if (ret != CMKEM_SUCCEED) {
        km_err_msg(kt->kmgr.err, "%s", get_cmkem_errmsg(ret));
        return plain;
    }

    plain.val = _plain->ustr_val;
    plain.len = _plain->ustr_len;
    km_free(_plain);
    return plain;
}

/*
 * kmgr_new callback: initialize gs_ktool (when built with ENABLE_KT) and allocate a
 * zero-filled GsKtoolMgr with an empty cmk cache.
 * Returns NULL after reporting through `err` if any step fails. Resources acquired by
 * earlier steps are released before returning, so a failed call leaves nothing behind.
 */
KeyMgr* ktool_new(KmErr *err)
{
#ifdef ENABLE_KT
    if (!init_gs_ktool()) {
        km_err_msg(err, "failed to init gs_ktool.");
        return NULL;
    }
#endif
    GsKtoolMgr *ktmgr = (GsKtoolMgr *)km_alloc_zero(sizeof(GsKtoolMgr));
    if (ktmgr == NULL) {
        /* previously execution fell through and dereferenced the NULL pointer */
        km_err_msg(err, "failed to malloc memory");
#ifdef ENABLE_KT
        deinit_gs_ktool();
#endif
        return NULL;
    }

    ktmgr->kmgr.err = err;
    ktmgr->cache = init_cmk_cache_list();
    if (ktmgr->cache == NULL) {
        /* the cache functions dereference ktool->cache, so a manager without one is unusable */
        km_err_msg(err, "failed to malloc memory");
        km_free(ktmgr);
#ifdef ENABLE_KT
        deinit_gs_ktool();
#endif
        return NULL;
    }

    return (KeyMgr *)ktmgr;
}

/*
 * kmgr_free callback: release the manager's cmk cache, shut gs_ktool down (when built with
 * ENABLE_KT) and free the manager itself. Passing NULL does nothing.
 */
static void ktool_free(KeyMgr *ktmgr)
{
    if (ktmgr != NULL) {
        GsKtoolMgr *mgr = (GsKtoolMgr *)(void *)ktmgr;

        free_cmk_cache_list(mgr->cache);
#ifdef ENABLE_KT
        deinit_gs_ktool();
#endif
        km_free(mgr);
    }
}

/*
 * Method table registering gs_ktool as a key store, initialized positionally in KeyMethod
 * field order (the trailing comments name each slot).
 * Only manager construction/destruction and mk_select / mk_encrypt / mk_decrypt are
 * implemented. Master-key creation/deletion, argument setting and data-key creation are
 * left NULL here, presumably because keys are generated with the gs_ktool tool itself
 * (see the file header).
 */
KeyMethod gs_ktool = {
    "gs_ktool",

    ktool_new, /* kmgr_new */
    ktool_free, /* kmgr_free */
    NULL, /* kmgr_set_arg */

    NULL, /* mk_create */
    NULL, /* mk_delete */
    ktool_mk_select, /* mk_select */
    ktool_mk_encrypt, /* mk_encrypt */
    ktool_mk_decrypt, /* mk_decrypt */

    NULL, /* dk_create */
};
